{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自主避障 - 模型训练\n",
    "\n",
    "欢迎来到主机端Jupyter Notebook！ 如果您浏览了在机器人上运行的笔记本，这应该看起来很熟悉。 在此笔记本中，我们将训练图像分类器来检测两个类别\n",
    "free和blocked，我们将使用它们来避免冲突。 为此，我们将使用流行的深度学习库* PyTorch *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 加载并提取数据集\n",
    "\n",
    "在开始之前，您应该上传在机器人上的data_collection.ipynb笔记本中创建的dataset.zip文件。\n",
    "\n",
    "然后，您应该通过调用以下命令来提取此数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!unzip -q dataset.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "您应该看到一个名为``dataset''的文件夹出现在文件浏览器中。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 创建数据集（dataset）实例"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在，我们使用torchvision.datasets包中可用的ImageFolder数据集类。 我们从``torchvision.transforms''包中附加转换，以准备训练数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.ImageFolder(\n",
    "    'dataset',\n",
    "    transforms.Compose([\n",
    "        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),\n",
    "        transforms.Resize((224, 224)),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "    ])\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据集分割成训练集和测试集"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "接下来，我们将数据集分为* training *和* test *集。 测试集将用于验证我们训练的模型的准确性。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 50, 50])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 创建数据加载器以批量加载数据"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们将创建两个“ DataLoader”实例，这些实例提供了用于对数据进行混排，生成*分批量*图像以及与多个工作程序并行加载样本的实用程序。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=16,\n",
    "    shuffle=True,\n",
    "    num_workers=4\n",
    ")\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=16,\n",
    "    shuffle=True,\n",
    "    num_workers=4\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义神经网络\n",
    "\n",
    "现在，我们定义将要训练的神经网络。 * torchvision *软件包提供了我们可以使用的预训练模型的集合。\n",
    "\n",
    "在一个称为“转移学习”的过程中，我们可以将预先训练的模型（针对数百万张图像进行训练）重新用于可能需要更少数据的新任务。\n",
    "\n",
    "在预训练模型的原始训练中学习到的重要功能可以重新用于新任务。 我们将使用“ alexnet”模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.alexnet(pretrained=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "``alexnet''模型最初是为具有1000个类别标签的数据集训练的，但是我们的数据集只有两个类别标签！ 我们将替换\n",
    "最后一层是一个新的未经训练的层，只有两个输出。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后，我们转移模型以在GPU上执行"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda')\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练神经网络\n",
    "\n",
    "使用下面的代码，我们将训练30个时期的神经网络，并在每个时期之后保存性能最佳的模型。\n",
    "\n",
    ">一个epoch是我们数据的完整记录。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "NUM_EPOCHS = 30\n",
    "BEST_MODEL_PATH = 'best_model.pth'\n",
    "best_accuracy = 0.0\n",
    "\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    \n",
    "    for images, labels in iter(train_loader):\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(images)\n",
    "        loss = F.cross_entropy(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "    test_error_count = 0.0\n",
    "    for images, labels in iter(test_loader):\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "        outputs = model(images)\n",
    "        test_error_count += float(torch.sum(torch.abs(labels - outputs.argmax(1))))\n",
    "    \n",
    "    test_accuracy = 1.0 - float(test_error_count) / float(len(test_dataset))\n",
    "    print('%d: %f' % (epoch, test_accuracy))\n",
    "    if test_accuracy > best_accuracy:\n",
    "        torch.save(model.state_dict(), BEST_MODEL_PATH)\n",
    "        best_accuracy = test_accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "完成后，您应该在Jupyter Lab文件浏览器中看到文件``best_model.pth''。 选择``右键单击''->``下载''以将模型下载到您的工作站"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
